{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5b23b9a2",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "# Region and Artifact Detection Tutorial\n",
    "\n",
    "This tutorial walks you to perform patch based region and artifact detection using KRONOS. The workflow involves extracting token-specific embeddings from input images and classifying each token.\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "To follow this tutorial, ensure you have the following data prepared:\n",
    "\n",
    " - A set of annotated 256x256 patches extracted from multiplex images and saved as .h5 format under patches folder in your project directory. \n",
    " - **patch_annotations.csv** cantaining two columns `patch_name` and `label`.\n",
    " - **marker_info_with_metadata.csv**: A CSV file containing:\n",
    "    - `channel_id`: Identifier for the image channel.\n",
    "    - `marker_name`: Name of the marker (e.g., DAPI, CD20).\n",
    "    - `marker_id`: Unique ID for the marker.\n",
    "    - `marker_mean`: Mean intensity value for normalization.\n",
    "    - `marker_std`: Standard deviation for normalization.\n",
    "\n",
    " **Note**: Refer to the **[1-Data-Download-And-Preprocessing](https://github.com/mahmoodlab/KRONOS/blob/main/tutorials/1%20-%20Data-Download-And-Preprocessing.ipynb)** tutorial to prepare data and refer to **[3 - Patch-phenotyping](https://github.com/mahmoodlab/KRONOS/blob/main/tutorials/3%20-%20Patch-phenotyping.ipynb)** tutorial to exptract patches from multiplex images."
   ]
  },
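  {
   "cell_type": "markdown",
   "id": "c7d1e4a9",
   "metadata": {},
   "source": [
    "The column requirements listed above can be checked before running the pipeline. The cell below is a minimal sketch using only the Python standard library; the helper name `missing_columns`, the sample rows, and the label values are illustrative assumptions, not part of KRONOS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f8b6d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "# Column names taken from the Prerequisites list\n",
    "MARKER_COLUMNS = ['channel_id', 'marker_name', 'marker_id', 'marker_mean', 'marker_std']\n",
    "ANNOTATION_COLUMNS = ['patch_name', 'label']\n",
    "\n",
    "def missing_columns(header, required):\n",
    "    # Return the required column names that are absent from the CSV header\n",
    "    return [name for name in required if name not in header]\n",
    "\n",
    "# Made-up rows for illustration; for a real file, pass open(path, newline='') to csv.reader\n",
    "rows = ['patch_name,label', 'patch_0001,tumor', 'patch_0002,artifact']\n",
    "header = next(csv.reader(rows))\n",
    "print(missing_columns(header, ANNOTATION_COLUMNS))\n",
    "\n",
    "# A header that lacks three of the marker metadata columns\n",
    "print(missing_columns(['channel_id', 'marker_name'], MARKER_COLUMNS))"
   ]
  },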
  {
   "cell_type": "markdown",
   "id": "bbe57459",
   "metadata": {},
   "source": [
    "## Step 1: Experiment Configuration\n",
    "\n",
    "In this section, we define the configuration and hyperparameters required for the workflow. Ensure your dataset in the project directory is organized as follows:\n",
    "\n",
    "### Dataset Structure\n",
    "- **`dataset/multiplex_images/`**: Contains the multiplex image files (e.g., `.tiff`).\n",
    "- **`dataset/marker_info_with_metadata.csv`**: A CSV file containing marker metadata.\n",
    "- **`dataset/patch_annotations.csv`**: A CSV file containing ground truth annotations.\n",
    "\n",
    "### Output Directories\n",
    "The following directories in the project directory will be generated during the workflow:\n",
    "- **`patches/`**: Stores the extracted cell-centered patches in `.h5` format.\n",
    "- **`features/`**: Stores the extracted features in `.npy` format.\n",
    "- **`folds/`**: Contains cross-validation folds for training, validation, and testing.\n",
    "- **`results/`**: Stores the results for each fold and aggregated metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32017d0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import RegionArtifactClassification \n",
    "\n",
    "# Define the root directory for the project\n",
    "project_dir = \"path/to/project/dir/\"  # Replace with your actual root directory\n",
    "\n",
    "# Configuration dictionary containing all parameters for the pipeline\n",
    "config = {\n",
    "    # Dataset-related parameters, feel free to modify these paths according to your dataset structure\n",
    "    \"image_dir_path\": f\"{project_dir}/dataset/multiplex_images/\",  # Path to the multiplex image file\n",
    "    \"cell_mask_dir_path\": f\"{project_dir}/dataset/cell_masks/\",  # Path to the cell mask file\n",
    "    \"marker_info_with_metadata_csv_path\": f\"{project_dir}/dataset/marker_info_with_metadata.csv\",  # Path to the marker metadata CSV file\n",
    "    \"gt_csv_path\": f\"{project_dir}/dataset/cell_annotations.csv\",  # Path to the ground truth annotations CSV file\n",
    "    \n",
    "    # Output directories for intermediate and final results\n",
    "    \"patch_dir\": f\"{project_dir}/patches/\",  # Directory to save extracted patches\n",
    "    \"feature_dir\": f\"{project_dir}/features/\",  # Directory to save extracted features\n",
    "    \"fold_dir\": f\"{project_dir}/folds/\",  # Directory to save cross-validation folds\n",
    "    \"result_dir\": f\"{project_dir}/results/\",  # Directory to save final results\n",
    "    \n",
    "    # Model-related parameters\n",
    "    \"checkpoint_path\": \"hf_hub:MahmoodLab/kronos\",  # Path to the pre-trained model checkpoint (local or Hugging Face Hub)\n",
    "    \"hf_auth_token\": None,  # Authentication token for Hugging Face Hub (if checkpoint is from the Hugging Face Hub)\n",
    "    \"cache_dir\": f\"{project_dir}/models/\",  # Directory to cache KRONOS model if downloading model from Hugging Face Hub\n",
    "    \"model_type\": \"vits16\",  # Type of pre-trained model to use (e.g., vits16)\n",
    "    \"token_overlap\": True,  # Whether to use token overlap during feature extraction\n",
    "\n",
    "    # Patch extraction parameters\n",
    "    \"patch_size\": 32,  # Size of the patches to extract (e.g., 64x64 pixels)\n",
    "    \"patch_stride\": 16,  # Stride for patch extraction (e.g., 32 pixels)\n",
    "    \"token_size\": 16,  # Size of the tokens for the model (e.g., 16x16 pixels)\n",
    "    \"token_stride\": 8,  # Stride for token extraction (e.g., 16 or 8 pixels)\n",
    "\n",
    "    # Feature extraction parameters\n",
    "    \"marker_list\": ['DAPI'],  # List of markers to be used for patch phenotyping\n",
    "    \"marker_max_values\": 65535.0,  # Maximum possible value marker image, depends on image type (e.g., 255 for uint8, 65535 for uint16)\n",
    "    \"patch_batch_size\": 16,  # Batch size for patch-based data loading\n",
    "    \"num_workers\": 4,  # Number of workers for data loading\n",
    "\n",
    "    # Logistic regression parameters\n",
    "    \"feature_batch_size\": 256,  # Batch size for feature-based data loading\n",
    "    \"n_trials\": 25,  # Number of trials for Optuna hyperparameter optimization\n",
    "    \"C_low\": 1e-10,  # Lower bound for the regularization parameter (C) in logistic regression\n",
    "    \"C_high\": 1e5,  # Upper bound for the regularization parameter (C) in logistic regression\n",
    "    \"max_iter\": 10000,  # Maximum number of iterations for logistic regression training\n",
    "}\n",
    "obj = RegionArtifactClassification(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "267c324f",
   "metadata": {},
   "source": [
    "## Step 2: Dummy Dataset Generation (Optional) \n",
    "\n",
    "If you are not sure about exact input format of the data then run following script to generate .h5 patches that cantain random values but will be useful to run this tutorial end-to-end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bbc85da",
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.generate_dummy_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa66950b",
   "metadata": {},
   "source": [
    "## Step 3: Feature Extraction\n",
    "\n",
    "Extract features from patches using a pre-trained model. This step processes patches of multiplex images, applies a pre-trained model to extract meaningful feature vectors, and saves them for downstream analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83505db8",
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.feature_extraction()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1aba4067",
   "metadata": {},
   "source": [
    "## Step 4: Generate Data Folds\n",
    "\n",
    "Generate data folds for training, validation, and testing. This tutorial split the patches listed in the annotation.csv into four equal sets based on their order in the list. Each set represents one fold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "940845f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.folds_generation()"
   ]
  },
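  {
   "cell_type": "markdown",
   "id": "4a9c0b17",
   "metadata": {},
   "source": [
    "The order-based split described in Step 4 can be sketched in plain Python. This is an illustration, not the code that `folds_generation()` runs: the function name `split_in_order` and the patch names are hypothetical, giving the leftover items to the first folds when the list does not divide evenly is an assumption, and the assignment of folds to training, validation, and testing is not shown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d3e5f21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_in_order(items, n_folds=4):\n",
    "    # Split a list into n_folds contiguous chunks, preserving the original order.\n",
    "    # If len(items) is not divisible by n_folds, the first chunks get one extra item.\n",
    "    base, extra = divmod(len(items), n_folds)\n",
    "    folds, start = [], 0\n",
    "    for i in range(n_folds):\n",
    "        size = base + (1 if i < extra else 0)\n",
    "        folds.append(items[start:start + size])\n",
    "        start += size\n",
    "    return folds\n",
    "\n",
    "# Hypothetical patch names standing in for the patch_name column\n",
    "patch_names = [f'patch_{i:04d}' for i in range(8)]\n",
    "for fold_id, fold in enumerate(split_in_order(patch_names)):\n",
    "    print(fold_id, fold)"
   ]
  },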
  {
   "cell_type": "markdown",
   "id": "fb08bdbf",
   "metadata": {},
   "source": [
    "## Step 5: Train Region or Artifact Detection Model\n",
    "\n",
    "Train a logistic regression model on the training data, validate it on the validation data, and evaluate it on the test data. Results for each fold are saved in the output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe91191",
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.train_classification_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9db0253",
   "metadata": {},
   "source": [
    "## Step 6: Evaluate Model\n",
    "\n",
    "Read the test results for each fold, compute evaluation metrics (F1-Score, Balanced Accuracy, Average Precision, and ROC AUC), and calculate the average and standard deviation across all folds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dc6ca25",
   "metadata": {},
   "outputs": [],
   "source": [
    "obj.eval_classification_model()\n",
    "obj.calculate_results()"
   ]
  }
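  ,
  {
   "cell_type": "markdown",
   "id": "6b1f8c42",
   "metadata": {},
   "source": [
    "The aggregation described in Step 6 amounts to a mean and a standard deviation per metric across folds. The cell below is a minimal sketch with made-up per-fold scores. It assumes the population standard deviation (`statistics.pstdev`), whereas `calculate_results()` may use a different convention such as the sample standard deviation, and the function name `summarize` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e7a2d95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "\n",
    "def summarize(fold_scores):\n",
    "    # Map each metric name to (mean, population standard deviation) across folds\n",
    "    summary = {}\n",
    "    for metric in fold_scores[0]:\n",
    "        values = [scores[metric] for scores in fold_scores]\n",
    "        summary[metric] = (statistics.mean(values), statistics.pstdev(values))\n",
    "    return summary\n",
    "\n",
    "# Made-up scores for four folds, for illustration only\n",
    "fold_scores = [\n",
    "    {'f1': 0.80, 'balanced_accuracy': 0.75},\n",
    "    {'f1': 0.84, 'balanced_accuracy': 0.79},\n",
    "    {'f1': 0.78, 'balanced_accuracy': 0.74},\n",
    "    {'f1': 0.82, 'balanced_accuracy': 0.80},\n",
    "]\n",
    "for metric, (mean, std) in summarize(fold_scores).items():\n",
    "    print(f'{metric}: {mean:.3f} +/- {std:.3f}')"
   ]
  }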
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rapids-24.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
